#!/usr/bin/env python3
"""
Generate bit-depth-optimized CIE 1931 lookup tables for ESP32-HUB75-MatrixPanel-DMA

Based on:
- https://ledshield.wordpress.com/2012/11/13/led-brightness-to-your-eye-gamma-correction-no/
- https://gist.github.com/mathiasvr/19ce1d7b6caeab230934080ae1f1380e

Generates native LUTs for common bit depths to eliminate runtime conversion overhead
and improve visual quality by avoiding double-quantization artifacts.
"""

import os
from collections import Counter


def cie1931(L):
    """
    Map perceived lightness to relative luminance (inverse CIE lightness curve)

    Uses a linear segment for very dark values and a cube law above it,
    with the 902.3 slope taken from the referenced ledshield article.

    Args:
        L: Lightness value (0-100 range)

    Returns:
        Luminance value (0-1 range)
    """
    # Above the knee at L = 8 the curve follows a cube law ...
    if L > 8:
        return ((L + 16.0) / 116.0) ** 3
    # ... and below (or at) it, luminance is proportional to lightness.
    return L / 902.3


def generate_lut(bit_depth):
    """
    Build the 256-entry CIE 1931 lookup table for one output bit depth

    Each 8-bit input code is mapped linearly onto lightness 0-100, converted
    to relative luminance with cie1931(), then scaled to the full range of
    ``bit_depth`` bits and rounded with Python's built-in round().

    Args:
        bit_depth: Target bit depth (6-12)

    Returns:
        List of 256 ints scaled to the range 0 .. (2**bit_depth - 1)
    """
    full_scale = (1 << bit_depth) - 1
    return [
        round(cie1931(code / 255.0 * 100.0) * full_scale)
        for code in range(256)
    ]


def format_lut_as_c_array(lut, bit_depth, values_per_line=16):
    """
    Format LUT as C array with proper type

    The element type is the narrowest standard fixed-width unsigned type that
    can hold ``bit_depth`` bits: uint8_t (<= 8), uint16_t (<= 16), otherwise
    uint32_t. The body is rows of right-aligned values, each row indented two
    spaces, with every value (including the last) followed by a comma, which
    C initializer lists permit.

    Args:
        lut: List of LUT values
        bit_depth: Bit depth (determines uint8_t / uint16_t / uint32_t)
        values_per_line: Number of values per line in output

    Returns:
        Tuple of (data_type, formatted_string). formatted_string is empty
        for an empty LUT.
    """
    if bit_depth <= 8:
        dtype = "uint8_t"
    elif bit_depth <= 16:
        dtype = "uint16_t"
    else:
        # Wider tables would silently truncate in a uint16_t array.
        dtype = "uint32_t"

    if len(lut) == 0:
        # Avoid emitting a lone "," which is not a valid C initializer.
        return dtype, ""

    lines = []
    for start in range(0, len(lut), values_per_line):
        chunk = lut[start:start + values_per_line]
        lines.append("  " + ", ".join(f"{v:5d}" for v in chunk))

    return dtype, ",\n".join(lines) + ","


def validate_lut(lut, bit_depth):
    """
    Validate LUT correctness and print statistics

    Checks that the table starts at 0, ends at the full-scale value for
    ``bit_depth`` and never decreases. It then prints how many distinct
    output codes the table uses and the largest number of inputs that
    collapse onto a single output.

    Failures raise AssertionError explicitly rather than via ``assert``
    statements, so validation still runs when Python is started with -O.

    Args:
        lut: Generated LUT (non-empty sequence of ints; normally 256 entries)
        bit_depth: Target bit depth

    Returns:
        True if validation passes

    Raises:
        AssertionError: If the LUT is empty, has wrong endpoints, or is not
            monotonically non-decreasing
    """
    max_val = (1 << bit_depth) - 1

    if len(lut) == 0:
        raise AssertionError(f"{bit_depth}-bit LUT is empty")

    # Check range
    last = len(lut) - 1
    if lut[0] != 0:
        raise AssertionError(f"{bit_depth}-bit LUT[0] should be 0, got {lut[0]}")
    if lut[last] != max_val:
        raise AssertionError(
            f"{bit_depth}-bit LUT[{last}] should be {max_val}, got {lut[last]}"
        )

    # Check monotonicity (values should never decrease)
    for i, (current, following) in enumerate(zip(lut, lut[1:])):
        if current > following:
            raise AssertionError(
                f"{bit_depth}-bit LUT not monotonic at index {i}: {current} > {following}"
            )

    # Calculate statistics
    unique = len(set(lut))
    worst_collision = max(Counter(lut).values())

    print(f"  {bit_depth:2d}-bit: {unique:3d}/{len(lut)} unique values (range 0-{max_val:4d}), "
          f"worst collision: {worst_collision:2d} inputs -> 1 output")

    return True


def _lut_guard(depth, native_depths, fallback_depth):
    """
    Build the preprocessor ``#if`` line that controls emission of one table

    A table is emitted when PIXEL_COLOR_DEPTH_BITS is undefined or equals its
    depth. The fallback table is also emitted for every depth that has no
    native table, because the selection block aliases lumConvTab to it then.

    Args:
        depth: Bit depth of the table being guarded
        native_depths: All bit depths that get a native table
        fallback_depth: Bit depth whose table serves non-native depths

    Returns:
        A single ``#if ...`` line without a trailing newline
    """
    conditions = [
        "!defined(PIXEL_COLOR_DEPTH_BITS)",
        f"PIXEL_COLOR_DEPTH_BITS == {depth}",
    ]
    if depth == fallback_depth:
        others = [f"PIXEL_COLOR_DEPTH_BITS != {d}" for d in native_depths if d != depth]
        if not others:
            # No other native table exists, so the fallback is always needed.
            return "#if 1"
        conditions.append("(" + " && ".join(others) + ")")
    return "#if " + " || ".join(conditions)


def _selection_block(native_depths, fallback_depth):
    """
    Build the trailing preprocessor section that aliases lumConvTab

    Args:
        native_depths: Bit depths with a native table, in #if/#elif order
        fallback_depth: Bit depth whose table is used for any other depth

    Returns:
        Header text from the selection banner through ``#endif // NO_CIE1931``
    """
    lines = [
        "",
        "// ============================================================================",
        "// Compile-time selection of appropriate LUT based on PIXEL_COLOR_DEPTH_BITS",
        "// ============================================================================",
        "",
        "#if !defined(PIXEL_COLOR_DEPTH_BITS)",
        "  #define PIXEL_COLOR_DEPTH_BITS 8",
        "#endif",
        "",
    ]
    for index, depth in enumerate(native_depths):
        directive = "#if" if index == 0 else "#elif"
        lines += [
            f"{directive} PIXEL_COLOR_DEPTH_BITS == {depth}",
            f"  #define lumConvTab lumConvTab_{depth}bit",
            "  #define LUT_NATIVE_BIT_DEPTH 1",
        ]
    lines += [
        "#else",
        "  // Fallback for non-standard bit depths (5, 9, 11, etc.)",
        f"  // Uses {fallback_depth}-bit LUT with runtime shift+round conversion",
        f"  #define lumConvTab lumConvTab_{fallback_depth}bit",
        "  #define LUT_NATIVE_BIT_DEPTH 0",
        f'  #warning "Using non-native CIE LUT bit depth - using {fallback_depth}-bit LUT with runtime conversion"',
        "#endif",
        "",
        "#endif // NO_CIE1931",
        "",
    ]
    return "\n".join(lines)


def generate_header():
    """
    Generate complete cie_luts.h header file with all LUT tables

    Emits one guarded table per native bit depth, then a preprocessor chain
    that aliases ``lumConvTab`` to the table matching PIXEL_COLOR_DEPTH_BITS.
    The fallback (12-bit) table is also emitted for non-native depths, so the
    fallback alias always refers to a declared array.

    Returns:
        String containing complete header file content

    Raises:
        AssertionError: If any generated LUT fails validate_lut()
    """

    preamble = """// Auto-generated file - DO NOT EDIT manually
// Generated by tools/generate_cie_luts.py
//
// CIE 1931 lightness lookup tables optimized for different bit depths
//
// This file contains pre-computed lookup tables that map 8-bit RGB input values (0-255)
// to perceptually-linear output values at various bit depths. Native bit-depth lookup
// tables eliminate runtime conversion overhead and improve gradient quality.
//
// Based on:
//   - https://ledshield.wordpress.com/2012/11/13/led-brightness-to-your-eye-gamma-correction-no/
//   - https://gist.github.com/mathiasvr/19ce1d7b6caeab230934080ae1f1380e
//
// Formula: CIE 1931 lightness curve
//   For L ≤ 8:    Y = L / 902.3
//   For L > 8:    Y = ((L + 16) / 116)³
//   Where L = input brightness (0-100), Y = output luminance (0-1)

#pragma once

#ifndef NO_CIE1931

#include <stdint.h>

"""

    # Generate LUTs for common bit depths
    bit_depths = [6, 7, 8, 10, 12]
    # Table used (with runtime conversion) for any depth not listed above
    fallback_depth = 12

    print("\n=== Generating CIE 1931 Lookup Tables ===\n")

    sections = [preamble]
    for depth in bit_depths:
        lut = generate_lut(depth)
        dtype, formatted_lut = format_lut_as_c_array(lut, depth)

        # Validate and print statistics
        validate_lut(lut, depth)

        max_val = (1 << depth) - 1
        sections.append(f"""
{_lut_guard(depth, bit_depths, fallback_depth)}
// {depth}-bit CIE 1931 lookup table
// Maps 8-bit input (0-255) to {depth}-bit output (0-{max_val})
static const {dtype} lumConvTab_{depth}bit[{len(lut)}] = {{
{formatted_lut}
}};
#endif
""")

    # Add compile-time selection logic
    sections.append(_selection_block(bit_depths, fallback_depth))

    print("\n[OK] All LUTs validated successfully\n")

    return "".join(sections)


def main():
    """
    Main entry point

    Generates the header text and writes it to ``../src/cie_luts.h``,
    resolved relative to this script's own directory.
    """
    # Generate header content
    header_content = generate_header()

    # Determine output path (relative to script location)
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_path = os.path.normpath(os.path.join(script_dir, "..", "src", "cie_luts.h"))

    # Write to file. newline="\n" forces LF endings on every host OS, so
    # regenerating on Windows does not rewrite every line of the header.
    with open(output_path, "w", encoding="utf-8", newline="\n") as f:
        f.write(header_content)

    print(f"[OK] Generated: {output_path}")
    print("\nNative LUTs available for: 6, 7, 8, 10, 12-bit depths")
    print("Other depths will use 12-bit LUT with runtime conversion\n")


if __name__ == "__main__":
    # Generate and write the header only when run as a script, not on import.
    main()
